/**
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include <cstdio>
#include <cstdlib>
#include "aclnn_batch_matmul_custom.h"
#include "acl/acl_rt.h"
#include "acl/acl.h"
#include "../../../../../common/data_utils.h"

/**
 * Initializes ACL, binds the calling thread to `device` and creates a stream on it.
 *
 * @param device  Device id passed to aclrtSetDevice.
 * @return The created stream, or nullptr if any step failed. On failure every
 *         step that had already succeeded is rolled back before returning.
 */
aclrtStream CreateStream(int32_t device)
{
    if (aclInit(nullptr) != ACL_SUCCESS) {
        printf("acl init failed\n");
        return nullptr;
    }
    if (aclrtSetDevice(device) != ACL_SUCCESS) {
        printf("Set device failed\n");
        CHECK_ACL(aclFinalize());
        return nullptr;
    }
    aclrtStream stream = nullptr;
    if (aclrtCreateStream(&stream) != ACL_SUCCESS) {
        printf("Create stream failed\n");
        // The device was successfully set above, so release it before finalizing;
        // this mirrors the teardown order used by DestroyStream().
        CHECK_ACL(aclrtResetDevice(device));
        CHECK_ACL(aclFinalize());
        return nullptr;
    }
    return stream;
}

/**
 * Tears down what CreateStream() set up: destroys the stream, resets the
 * device and finalizes ACL, in that order. Failures of the last two steps are
 * reported on stdout; teardown continues regardless.
 *
 * @param stream  Stream to destroy.
 * @param device  Device id to reset.
 */
void DestroyStream(aclrtStream stream, int32_t device)
{
    CHECK_ACL(aclrtDestroyStream(stream));

    const bool resetOk = (aclrtResetDevice(device) == ACL_SUCCESS);
    if (!resetOk) {
        printf("Reset device failed\n");
    }

    const bool finalizeOk = (aclFinalize() == ACL_SUCCESS);
    if (!finalizeOk) {
        printf("Finalize acl failed\n");
    }
}

/**
 * Releases the device buffers and tensor handles held in two parallel arrays.
 *
 * A buffer and its tensor are released independently of each other, so a
 * buffer whose tensor handle is null (for example because tensor creation
 * failed after the allocation succeeded) is still freed. Released slots are
 * set to nullptr, which makes a repeated call harmless.
 *
 * @param tensors      Array of tensor handles; null entries are skipped.
 * @param devMem       Array of device buffers; null entries are skipped.
 * @param tensorCount  Number of entries in both arrays.
 */
void DestroyTensor(aclTensor *tensors[], void *devMem[], int32_t tensorCount)
{
    for (int32_t i = 0; i < tensorCount; i++) {
        // Keep the original order: free the backing memory first, then the handle.
        if (devMem[i] != nullptr) {
            CHECK_ACL(aclrtFree(devMem[i]));
            devMem[i] = nullptr;
        }
        if (tensors[i] != nullptr) {
            CHECK_ACL(aclDestroyTensor(tensors[i]));
            tensors[i] = nullptr;
        }
    }
}

// Description of one tensor used by this test: its shape, element type and format.
// Instances are built with aggregate initialization, so the field order matters.
struct tensorInfo {
    // Extent of each dimension; GetDataSize() returns 0 when this pointer is null.
    int64_t *dims;
    // Number of entries in dims.
    int64_t dimCnt;
    // Element type handed to aclCreateTensor.
    aclDataType dtype;
    // Format handed to aclCreateTensor.
    aclFormat fmt;
};

/**
 * Computes the buffer size, in bytes, used for the tensor described by `desc`.
 *
 * @param desc  Tensor description; must not be null.
 * @return 0 when desc->dims is null, otherwise the product of all dimensions
 *         multiplied by sizeof(float).
 *
 * NOTE(review): the element width is always sizeof(float) and desc->dtype is
 * not consulted, so narrower element types get a larger size than they
 * strictly need — confirm this is intended before changing it.
 */
int64_t GetDataSize(struct tensorInfo *desc)
{
    if (desc->dims == nullptr) {
        return 0;
    }

    int64_t elementCount = 1;
    int64_t axis = 0;
    while (axis < desc->dimCnt) {
        elementCount *= desc->dims[axis];
        ++axis;
    }
    return elementCount * sizeof(float);
}

int64_t CompareResult(void *outputData, int64_t outSize)
{
    void *goldenData;
    CHECK_ACL(aclrtMallocHost((void**)(&goldenData), outSize));
    size_t goldenSize = outSize;
    bool ret = ReadFile("../output/golden.bin", goldenSize, goldenData, goldenSize);
    if (ret) {
        printf("ReadFile golden success!\n");
    } else {
        CHECK_ACL(aclrtFreeHost(goldenData));
        return -1;
    }
    constexpr float eps = 1e-3;
    int64_t wrongNum = 0;

    for (auto i = 0; i < outSize / sizeof(float); i++) {
        float a = ((float*)outputData)[i];
        float b = ((float*)goldenData)[i];
        float ae = std::abs(a - b);
        float re = ae / abs(b);
        if (ae > eps && re > eps) {
            printf("CompareResult failed output is %lf, golden is %lf\n", a, b);
            wrongNum++;
        }
    }
    CHECK_ACL(aclrtFreeHost(goldenData));
    return wrongNum;
}

int32_t main(void)
{
    // shape info
    int64_t shapeA[] = {192, 64};
    int64_t shapeB[] = {64, 1536};
    int64_t shapeBias[] = {1536};
    int64_t shapeC[] = {192, 256};
    struct tensorInfo tensorDesc[] = {{shapeA, 2, ACL_FLOAT16, ACL_FORMAT_ND},
                                      {shapeB, 2, ACL_FLOAT16, ACL_FORMAT_ND},
                                      {shapeBias, 1, ACL_FLOAT, ACL_FORMAT_ND},
                                      {shapeC, 2, ACL_FLOAT, ACL_FORMAT_ND},
    };
    aclrtStream stream = CreateStream(0);
    if (stream == nullptr) {
        return -1;
    }
    int32_t tensorCount = sizeof(tensorDesc) / sizeof(struct tensorInfo);
    aclTensor *tensors[tensorCount];
    void *devMem[tensorCount];
    constexpr int32_t aMatrixIndex = 0;
    constexpr int32_t bMatrixIndex = 1;
    constexpr int32_t biasIndex = 2;
    for (auto i = 0; i < tensorCount; i++) {
        void *data;
        struct tensorInfo *info = &(tensorDesc[i]);
        int64_t size = GetDataSize(info);
        if (size == 0) {
            tensors[i] = nullptr;
            devMem[i] = nullptr;
            continue;
        }
        CHECK_ACL(aclrtMalloc(&data, size, ACL_MEM_MALLOC_HUGE_FIRST));
        // read input
        if (i == aMatrixIndex) {
            size_t inputSize = size;
            void *dataHost;
            CHECK_ACL(aclrtMallocHost((void**)(&dataHost), inputSize));
            ReadFile("../input/input_a.bin", inputSize, dataHost, inputSize);
            CHECK_ACL(aclrtMemcpy(data, size, dataHost, size, ACL_MEMCPY_HOST_TO_DEVICE));
            CHECK_ACL(aclrtFreeHost(dataHost));
        }
        if (i == bMatrixIndex) {
            size_t inputSize = size;
            void *dataHost;
            CHECK_ACL(aclrtMallocHost((void**)(&dataHost), inputSize));
            ReadFile("../input/input_b.bin", inputSize, dataHost, inputSize);
            CHECK_ACL(aclrtMemcpy(data, size, dataHost, size, ACL_MEMCPY_HOST_TO_DEVICE));
            CHECK_ACL(aclrtFreeHost(dataHost));
        }
        if (i == biasIndex) {
            size_t inputSize = size;
            void *dataHost;
            CHECK_ACL(aclrtMallocHost((void**)(&dataHost), inputSize));
            ReadFile("../input/input_bias.bin", inputSize, dataHost, inputSize);
            CHECK_ACL(aclrtMemcpy(data, size, dataHost, size, ACL_MEMCPY_HOST_TO_DEVICE));
            CHECK_ACL(aclrtFreeHost(dataHost));
        }
        devMem[i] = data;
        tensors[i] = aclCreateTensor(info->dims, info->dimCnt, info->dtype, nullptr, 0, info->fmt, info->dims, info->dimCnt, data);
    }
    size_t workspaceSize = 0;
    aclOpExecutor *handle;
    int32_t ret = aclnnBatchMatmulCustomGetWorkspaceSize(tensors[0], tensors[1], tensors[2], tensors[3],
                                            &workspaceSize,
                                            &handle); // tensors_list is a, b, bias, c
    if (ret != ACL_SUCCESS) {
        printf("aclnnBatchMatmulCustomGetWorkspaceSize failed. error code is %u\n", ret);
        DestroyTensor(tensors, devMem, tensorCount);
        DestroyStream(stream, 0);
        return ret;
    }
    printf("aclnnBatchMatmulCustomGetWorkspaceSize ret %u workspace size %lu\n", ret, workspaceSize);
    void *workspace = nullptr;
    if (workspaceSize != 0) {
        CHECK_ACL(aclrtMalloc(&workspace, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST));
    }
    ret = aclnnBatchMatmulCustom(workspace, workspaceSize, handle, stream);
    printf("aclnnBatchMatmulCustom ret %u\n", ret);
    CHECK_ACL(aclrtSynchronizeStream(stream));

    uint8_t *zHost, *maxHost, *sumHost;
    int64_t zHostSize = GetDataSize(&(tensorDesc[3])); // 3 is c_index
    CHECK_ACL(aclrtMallocHost((void**)(&zHost), zHostSize));
    CHECK_ACL(aclrtMemcpy(zHost, zHostSize, devMem[3], zHostSize, ACL_MEMCPY_DEVICE_TO_HOST)); // 3 is c_index
    WriteFile("../output/output_z.bin", zHost, zHostSize);

    int64_t wrongNum = CompareResult(zHost, zHostSize);
    if (wrongNum != 0 ) {
        printf("test failed!\n");
    } else {
        printf("test pass!\n");
    }
    if (workspaceSize != 0) {
        CHECK_ACL(aclrtFree(workspace));
    }
    CHECK_ACL(aclrtFreeHost(zHost));

    DestroyTensor(tensors, devMem, tensorCount);
    DestroyStream(stream, 0);
    return 0;
}